import json
import copy
import random
import re
from functools import partial
from typing import List, Optional, Tuple, Union

import numpy
import pandas as pd
import torch
import nevergrad as ng
from datasets import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    PreTrainedTokenizerFast,
)
from transformers.generation.configuration_utils import GenerationConfig
from peft import PeftModel, PeftConfig
from peft.utils.save_and_load import set_peft_model_state_dict, get_peft_model_state_dict

def get_prompt_ds(path: str):
    """Load (prompt, output) pairs from a JSON file containing a list of records."""
    prompts = []
    outputs = []
    with open(path) as f:
        records = json.load(f)
    for record in records:
        prompts.append(record["prompt"])
        outputs.append(record["output"])
    return prompts, outputs
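
# Example usage (a minimal sketch; "data/train.json" is a hypothetical path to a
# JSON list of {"prompt": ..., "output": ...} records):
#
#   prompts, outputs = get_prompt_ds("data/train.json")
#   print(prompts[0], "->", outputs[0])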

def load_base_model_and_lora_modules(lora_module_list: List[str], model_name_or_path: Optional[str] = None):
    """Load the base model and LoRA modules from the Hugging Face model hub.

    Args:
        lora_module_list (List[str]): a list of LoRA module names available on the Hugging Face model hub
        model_name_or_path (Optional[str]): base model name; if None, it is inferred from the first LoRA module's config
    """
    # use gpu if available
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # load basic model
    default_peft_model_id = lora_module_list[0]
    # find the base model
    if model_name_or_path is None:
        model_name_or_path = PeftConfig.from_pretrained(default_peft_model_id).base_model_name_or_path

    base_model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, torch_dtype=torch.bfloat16)
    # load tokenizer
    tokenizer = PreTrainedTokenizerFast.from_pretrained(model_name_or_path)
    # 0 is the default model
    try:
        peft_model = PeftModel.from_pretrained(base_model, default_peft_model_id)
    except Exception as e:
        raise Exception(f'{default_peft_model_id} is unable to load into the model {model_name_or_path}') from e

    peft_model = peft_model.to(device)
    peft_model.eval()

    print("> Begin to load lora modules")
    cache = {}

    first_dict = None

    for peft_model_id in tqdm(lora_module_list):
        print("> Loading {} ...".format(peft_model_id))
        cur_peft_model = PeftModel.from_pretrained(base_model, peft_model_id)
        cache[peft_model_id] = copy.deepcopy(get_peft_model_state_dict(cur_peft_model))

        if first_dict is None:
            first_dict = cache[peft_model_id]
        # check that every LoRA module shares the same architecture (e.g., rank)
        # as the first one, otherwise they cannot be merged
        for key in first_dict.keys():
            if first_dict[key].shape != cache[peft_model_id][key].shape:
                raise Exception(f'LoRA module {peft_model_id} cannot be merged since it has a different arch (e.g., rank).')

    return peft_model, tokenizer, cache
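
# Example usage (a minimal sketch; the adapter ids below are hypothetical
# placeholders for seq2seq LoRA adapters on the Hugging Face Hub that share the
# same base model and LoRA rank):
#
#   lora_ids = ["user/lora-task-a", "user/lora-task-b"]
#   peft_model, tokenizer, cache = load_base_model_and_lora_modules(lora_ids)
#   # cache maps each adapter id to a deep-copied LoRA state dict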

def preprocess_function(examples, tokenizer):
    """
    standard preprocess function for dataset
    """
    inputs = examples["input"]
    targets = examples["output"]
    model_inputs = tokenizer(
        inputs,
        max_length=2048,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    labels = tokenizer(
        targets,
        max_length=2048,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    labels = labels["input_ids"]
    # mask padding positions so they are ignored by the cross-entropy loss
    labels[labels == tokenizer.pad_token_id] = -100
    model_inputs["labels"] = labels
    return model_inputs


def load_dataset(example_inputs, example_outputs, tokenizer):
    # add empty string if example_outputs is None
    if example_outputs is None:
        example_outputs = [""] * len(example_inputs)
    records = [
        {"input": example_inputs[i], "output": example_outputs[i]}
        for i in range(len(example_inputs))
    ]
    dataset = Dataset.from_pandas(pd.DataFrame(records))
    preprocess_func_with_tokenizer = partial(preprocess_function, tokenizer=tokenizer)
    processed_datasets = dataset.map(
        preprocess_func_with_tokenizer,
        batched=True,
        num_proc=1,
        desc="Running tokenizer on dataset",
    )
    return processed_datasets
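
# Example (a minimal sketch with toy strings; any seq2seq tokenizer works):
#
#   ds = load_dataset(["translate: hello"], ["bonjour"], tokenizer)
#   # ds keeps the raw "input"/"output" columns and adds the tokenized
#   # "input_ids", "attention_mask", and "labels" columns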


def default_get_loss(example_dataset, model, tokenizer, batch_size):
    """
    Get the loss of the model on the example dataset. Usually the example dataset only contains a few examples.
    """
    data_batch_size = len(example_dataset) if batch_size is None else min(len(example_dataset), batch_size)
    collate_fn = DataCollatorForSeq2Seq(
        model=model,
        tokenizer=tokenizer,
        padding="longest",
        max_length=2048,
        return_tensors="pt"
    )
    train_dataloader = DataLoader(
        example_dataset,
        collate_fn=collate_fn,
        batch_size=data_batch_size,
        pin_memory=True,
    )
    # use gpu if available
    device = "cuda" if torch.cuda.is_available() else "cpu"
    train_loss = 0
    with torch.no_grad():
        for batch in tqdm(train_dataloader):
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            train_loss += outputs.loss.detach().float()
    # average loss over the number of examples
    return float(train_loss) / len(example_dataset["input_ids"])

def default_l1_regularization(weights):
    """
    Get the L1 regularization term for the weights
    """
    mean_abs_weight = sum(abs(x) for x in weights) / len(weights)
    return 0.05 * mean_abs_weight
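
# Example: for weights [0.5, -1.0, 0.5] the mean absolute weight is
# (0.5 + 1.0 + 0.5) / 3 = 2/3, so the penalty is 0.05 * 2/3 ≈ 0.033.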

def get_score(weights, model, tokenizer,  cache, example_dataset, batch_size, get_loss, get_regular):
    # compose the cached LoRA state dicts with the candidate weights
    lora_module_list = list(cache.keys())
    final_state_dict = get_final_weights(weights, lora_module_list, cache)
    # reload the model with the composed adapter weights
    set_peft_model_state_dict(model, final_state_dict)

    # minimize the metric
    loss = get_loss(example_dataset, model, tokenizer, batch_size)
    # L1 regularization term
    metric_val = loss + get_regular(weights)

    return metric_val

def get_final_weights(weights, lora_module_list, cache):
    """Linearly combine the cached LoRA state dicts with the given weights."""
    final_state_dict = {}
    keys = cache[lora_module_list[0]].keys()
    for i, peft_model_id in enumerate(lora_module_list):
        lora_state_dict = cache[peft_model_id]
        if i == 0:
            for key in keys:
                final_state_dict[key] = weights[i] * lora_state_dict[key]
        else:
            for key in keys:
                final_state_dict[key] = (
                        final_state_dict[key] + weights[i] * lora_state_dict[key]
                )
    return final_state_dict
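
# Example (a minimal sketch with a toy cache; real caches come from
# load_base_model_and_lora_modules, and the key name "w" is made up):
#
#   cache = {
#       "lora_a": {"w": torch.tensor([1.0, 0.0])},
#       "lora_b": {"w": torch.tensor([0.0, 1.0])},
#   }
#   get_final_weights([0.5, 0.5], ["lora_a", "lora_b"], cache)
#   # -> {"w": tensor([0.5, 0.5])}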

def lorahub_inference(example_inputs: List[str],
                      model_or_name_path: Union[AutoModelForSeq2SeqLM, str],
                      tokenizer_or_tokenizer_path: Union[AutoTokenizer, str],
                      batch_size: int,
                      # if not provided, we do not report the accuracy
                      example_outputs: Optional[List[str]] = None):

    def accuracy_score(outputs, ground_truths):
        correct = 0
        total = 0
        for output, truth in zip(outputs, ground_truths):
            if output.strip().lower().replace(".", "") == truth.strip().lower().replace(".", ""):
                correct += 1
            total += 1
        return correct / total * 100

    example_predictions = []
    # load model
    if isinstance(model_or_name_path, str):
        model = AutoModelForSeq2SeqLM.from_pretrained(model_or_name_path)
    else:
        model = model_or_name_path

    # load tokenizer
    if isinstance(tokenizer_or_tokenizer_path, str):
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_or_tokenizer_path)
    else:
        tokenizer = tokenizer_or_tokenizer_path

    # process dataset
    dataset = load_dataset(example_inputs, example_outputs, tokenizer)
    # use gpu if available
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    for i in range(0, len(dataset["input"]), batch_size):
        inputs = tokenizer(
            dataset["input"][i : i + batch_size],
            max_length=2048,
            truncation=True,
            return_tensors="pt",
            padding=True,
        ).to(device)
        # pass the attention mask so padded positions are not attended to
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=256,
        )
        outputs = tokenizer.batch_decode(
            outputs.to("cpu"), skip_special_tokens=True
        )
        example_predictions.extend(outputs)

    if example_outputs is not None:
        task_perf = accuracy_score(example_predictions, example_outputs)
    else:
        task_perf = None

    return example_predictions, task_perf
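
# Example usage (a minimal sketch; "google/flan-t5-base" is only an illustrative
# seq2seq checkpoint, not necessarily the model this project targets):
#
#   preds, acc = lorahub_inference(
#       example_inputs=["What is the capital of France?"],
#       model_or_name_path="google/flan-t5-base",
#       tokenizer_or_tokenizer_path="google/flan-t5-base",
#       batch_size=8,
#       example_outputs=["Paris"],
#   )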


def re_calculate_metrics(predicted: List[str], actual: List[str]):
    # convert each string into a list of (head, relation, tail) tuples
    def re_parse_tuple_string(tuple_string: str) -> List[Tuple[str, str, str]]:
        # normalize full-width punctuation, then match tuples with a regex that
        # makes the second and third elements optional
        tuple_string = tuple_string.replace('，',',').replace('（','(').replace('）', ')').replace("：", ":").replace(" ", "")
        pattern = re.compile(r'\(([^,]*),? ?([^,]*)?,? ?([^)]*)?\)')
        tuples = pattern.findall(tuple_string)
        # replace empty matches with a sentinel tuple
        tuples = [('EMPTY', 'EMPTY', 'EMPTY') if t == ('', '', '') else t for t in tuples]
        return tuples

    predicted = [re_parse_tuple_string(t) for t in predicted]
    actual = [re_parse_tuple_string(t) for t in actual]

    # flatten the per-example tuple lists
    predicted_tuples = [item for sublist in predicted for item in sublist]
    actual_tuples = [item for sublist in actual for item in sublist]

    tp = len(set(predicted_tuples) & set(actual_tuples))
    fp = len(set(predicted_tuples) - set(actual_tuples))
    fn = len(set(actual_tuples) - set(predicted_tuples))

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

    return {"precision": precision, "recall": recall, "f1": f1}

def ner_calculate_metrics(predicted: List[str], actual: List[str]):
    # convert each string into a list of (entity, type) tuples
    def ner_parse_tuple_string(tuple_string: str) -> List[Tuple[str, str]]:
        # normalize full-width punctuation, then match tuples with a regex that
        # makes the second element optional
        tuple_string = tuple_string.replace('，',',').replace('（','(').replace('）', ')').replace("：", ":").replace(" ", "")
        pattern = re.compile(r'\(([^,]*),? ?([^)]*)?\)')
        tuples = pattern.findall(tuple_string)
        # replace empty matches with a sentinel tuple
        tuples = [('EMPTY', 'EMPTY') if t == ('', '') else t for t in tuples]
        return tuples

    predicted = [ner_parse_tuple_string(t) for t in predicted]
    actual = [ner_parse_tuple_string(t) for t in actual]

    # flatten the per-example tuple lists
    predicted_tuples = [item for sublist in predicted for item in sublist]
    actual_tuples = [item for sublist in actual for item in sublist]

    tp = len(set(predicted_tuples) & set(actual_tuples))
    fp = len(set(predicted_tuples) - set(actual_tuples))
    fn = len(set(actual_tuples) - set(predicted_tuples))

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

    return {"precision": precision, "recall": recall, "f1": f1}

def ee_calculate_metrics(predicted: List[str], actual: List[str]):
    def ee_parse_tuple_string(tuple_string: str) -> List[Tuple[str, str, List[Tuple[str, str]]]]:
        # parse the bracketed argument list into (argument, role) tuples
        def parse_tuple_string(tuple_string: str):
            # normalize full-width punctuation, then match tuples with a regex
            # that makes the second element optional
            tuple_string = tuple_string.replace('，',',').replace('（','(').replace('）', ')').replace("：", ":").replace(" ", "")
            pattern = re.compile(r'\(([^,]*),? ?([^)]*)?\)')
            tuples = pattern.findall(tuple_string)
            # replace empty matches with a sentinel tuple
            tuples = tuple([('EMPTY', 'EMPTY') if t == ('', '') else t for t in tuples])
            return tuples
        # normalize full-width punctuation, then match (trigger, event_type, [arguments])
        # tuples with a regex that makes the second and third elements optional
        tuple_string = tuple_string.replace('，',',').replace('（','(').replace('）', ')').replace("：", ":").replace(" ", "")
        pattern = re.compile(r'\(([^,]*),? ?([^,]*)?,? ?(\[.*\])?\)')
        tuples = pattern.findall(tuple_string)
        # replace empty matches with a sentinel tuple; sort the arguments so that
        # argument order does not affect tuple equality
        tuples = [('EMPTY', 'EMPTY', [('EMPTY', 'EMPTY')]) if t == ('', '', '') else (t[0], t[1], tuple(sorted(parse_tuple_string(t[2])))) for t in tuples]
        return tuples
    # convert each string into a list of (trigger, event_type, arguments) tuples
    predicted = [ee_parse_tuple_string(t) for t in predicted]
    actual = [ee_parse_tuple_string(t) for t in actual]

    # flatten the per-example tuple lists
    predicted_tuples = [item for sublist in predicted for item in sublist]
    actual_tuples = [item for sublist in actual for item in sublist]

    tp = len(set(predicted_tuples) & set(actual_tuples))
    fp = len(set(predicted_tuples) - set(actual_tuples))
    fn = len(set(actual_tuples) - set(predicted_tuples))

    # note: unlike the NER/RE metrics above, empty prediction/reference sets default to 1 here
    precision = tp / (tp + fp) if (tp + fp) > 0 else 1
    recall = tp / (tp + fn) if (tp + fn) > 0 else 1
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 1

    return {"precision": precision, "recall": recall, "f1": f1}

def assess_task_performance(
        example_inputs: List[str],
        model_or_name_path: Union[AutoModelForSeq2SeqLM, str],
        tokenizer_or_tokenizer_path: Union[AutoTokenizer, str],
        batch_size: int,
        # if not provided, we do not report the metrics
        example_outputs: Optional[List[str]] = None,
        task: str = "ner",
):
    example_predictions = []
    # load model
    if isinstance(model_or_name_path, str):
        model = AutoModelForSeq2SeqLM.from_pretrained(model_or_name_path)
    else:
        model = model_or_name_path

    # load tokenizer
    if isinstance(tokenizer_or_tokenizer_path, str):
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_or_tokenizer_path)
    else:
        tokenizer = tokenizer_or_tokenizer_path

    generation_config = GenerationConfig(
        remove_invalid_values=True,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        decoder_start_token_id=tokenizer.pad_token_id,
        max_new_tokens=256,
        repetition_penalty=1.5,
        num_beams=1,         # greedy search
        do_sample=False,
    )

    # process dataset
    dataset = load_dataset(example_inputs, example_outputs, tokenizer)
    # use gpu if available
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    for i in tqdm(range(0, len(dataset["input"]), batch_size)):
        inputs = tokenizer(
            dataset["input"][i : i + batch_size],
            max_length=2048,
            truncation=True,
            return_tensors="pt",
            padding=True,
        ).to(device)
        # pass the attention mask so padded positions are not attended to
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            generation_config=generation_config,
        )
        outputs = tokenizer.batch_decode(
            outputs.to("cpu"), skip_special_tokens=True
        )
        example_predictions.extend(outputs)

    ass_f = ""
    if task == "ner": ass_f = ner_calculate_metrics
    elif task == "ee" : ass_f = ee_calculate_metrics
    elif task == "re" : ass_f = re_calculate_metrics

    if example_outputs is not None:
        task_perf = ass_f(example_predictions, example_outputs)
    else:
        task_perf = None

    return example_predictions, task_perf
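
# Example usage (a minimal sketch; the input/reference strings are toy NER data
# in the tuple format expected by ner_calculate_metrics):
#
#   preds, metrics = assess_task_performance(
#       example_inputs=["Extract entities: Paris is lovely."],
#       model_or_name_path=model,           # an already-loaded seq2seq model
#       tokenizer_or_tokenizer_path=tokenizer,
#       batch_size=8,
#       example_outputs=["(Paris,LOC)"],
#       task="ner",
#   )
#   # metrics -> {"precision": ..., "recall": ..., "f1": ...}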

def lorahub_learning(lora_module_list: List[str],
                     example_inputs: List[str],
                     example_outputs: List[str],
                     max_inference_step: int,
                     model_name_or_path=None,
                     batch_size=None,
                     get_loss=default_get_loss,
                     get_regular=default_l1_regularization,
                     seed=42):
    """Learn composition weights for the given LoRA modules via gradient-free
    optimization, then merge the composed adapter into the base model.

    Returns a tuple of (weights, merged_model, tokenizer).
    """
    # set seed for reproducibility
    random.seed(seed)
    numpy.random.seed(seed)

    number_of_loras = len(lora_module_list)
    if number_of_loras == 0:
        raise ValueError("No LoRA modules are provided. Please provide at least one LoRA module.")

    # load model
    model, tokenizer, cache = load_base_model_and_lora_modules(lora_module_list, model_name_or_path)
    model.to("cuda")
    # process dataset
    dataset = load_dataset(example_inputs, example_outputs, tokenizer)
    # if "token_type_ids" in dataset.column_names:
    dataset = dataset.remove_columns(["token_type_ids","input", "output"])
    # dataset.select_columns(['input', 'input_ids', 'attention_mask', 'labels'])
    print(dataset)

    get_score_partial = partial(get_score,
                                model=model,
                                tokenizer=tokenizer,
                                cache=cache,
                                example_dataset=dataset,
                                batch_size=batch_size,
                                get_loss=get_loss,
                                get_regular=get_regular)
    # bound the search space of the composition weights
    instrum = ng.p.Array(
        init=[0.0] * number_of_loras,
        upper=[1.5] * number_of_loras,
        lower=[-1.5] * number_of_loras,
    )
    optimizer = ng.optimizers.NGOpt(parametrization=instrum, budget=max_inference_step)
    print("> Begin to perform gradient-free optimization ...")
    recommendation = optimizer.minimize(get_score_partial, verbosity=1)
    final_lora = get_final_weights(recommendation.value, lora_module_list, cache)
    # set the final weights
    set_peft_model_state_dict(model, final_lora)
    model = model.merge_and_unload()
    return recommendation.value, model, tokenizer
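
# End-to-end example (a minimal sketch; the adapter ids and data path are
# hypothetical placeholders):
#
#   prompts, outputs = get_prompt_ds("data/few_shot.json")
#   weights, model, tokenizer = lorahub_learning(
#       lora_module_list=["user/lora-task-a", "user/lora-task-b"],
#       example_inputs=prompts,
#       example_outputs=outputs,
#       max_inference_step=40,
#   )
#   preds, perf = lorahub_inference(prompts, model, tokenizer,
#                                   batch_size=8, example_outputs=outputs)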